{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 一、定义前向、后向传播\n",
    "\n",
    "本文将用numpy实现dnn, 并测试mnist手写数字识别\n",
    "\n",
    "如果对神经网络的反向传播过程还有不清楚的，可以[0_1-全连接层、损失函数的反向传播](0_1-全连接层、损失函数的反向传播.md)\n",
    "\n",
    "网络结构如下,包括3个fc层：\n",
    "input(28\\*28)=> fc (256) => relu => fc(256) => relu => fc(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# 定义权重、神经元、梯度\n",
    "weights={}\n",
    "weights_scale=1e-3\n",
    "weights[\"W1\"]=weights_scale*np.random.randn(28*28,256)\n",
    "weights[\"b1\"]=np.zeros(256)\n",
    "weights[\"W2\"]=weights_scale*np.random.randn(256,256)\n",
    "weights[\"b2\"]=np.zeros(256)\n",
    "weights[\"W3\"]=weights_scale*np.random.randn(256,10)\n",
    "weights[\"b3\"]=np.zeros(10)\n",
    "\n",
    "nuerons={}\n",
    "gradients={}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from nn.layers import fc_forward\n",
    "from nn.activations import relu_forward\n",
    "\n",
    "# 定义前向过程\n",
    "def forward(X):\n",
    "    nuerons[\"z2\"]=fc_forward(X,weights[\"W1\"],weights[\"b1\"])\n",
    "    nuerons[\"z2_relu\"]=relu_forward(nuerons[\"z2\"])\n",
    "    nuerons[\"z3\"]=fc_forward(nuerons[\"z2_relu\"],weights[\"W2\"],weights[\"b2\"])\n",
    "    nuerons[\"z3_relu\"]=relu_forward(nuerons[\"z3\"])\n",
    "    nuerons[\"y\"]=fc_forward(nuerons[\"z3_relu\"],weights[\"W3\"],weights[\"b3\"])\n",
    "    return nuerons[\"y\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from nn.losses import cross_entropy_loss\n",
    "from nn.layers import fc_backward\n",
    "from nn.activations import relu_backward\n",
    "\n",
    "# 定义后向过程\n",
    "def backward(X,y_true):\n",
    "    loss,dy=cross_entropy_loss(nuerons[\"y\"],y_true)\n",
    "    gradients[\"W3\"],gradients[\"b3\"],gradients[\"z3_relu\"]=fc_backward(dy,weights[\"W3\"],nuerons[\"z3_relu\"])\n",
    "    gradients[\"z3\"]=relu_backward(gradients[\"z3_relu\"],nuerons[\"z3\"])\n",
    "    gradients[\"W2\"],gradients[\"b2\"],gradients[\"z2_relu\"]=fc_backward(gradients[\"z3\"],\n",
    "                                                                     weights[\"W2\"],nuerons[\"z2_relu\"])\n",
    "    gradients[\"z2\"]=relu_backward(gradients[\"z2_relu\"],nuerons[\"z2\"])\n",
    "    gradients[\"W1\"],gradients[\"b1\"],_=fc_backward(gradients[\"z2\"],\n",
    "                                                    weights[\"W1\"],X)\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 获取精度\n",
    "def get_accuracy(X,y_true):\n",
    "    y_predict=forward(X)\n",
    "    return np.mean(np.equal(np.argmax(y_predict,axis=-1),\n",
    "                            np.argmax(y_true,axis=-1)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 二、加载数据\n",
    "\n",
    "mnist.pkl.gz数据源： http://deeplearning.net/data/mnist/mnist.pkl.gz   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from nn.load_mnist import load_mnist_datasets\n",
    "from nn.utils import to_categorical\n",
    "train_set, val_set, test_set = load_mnist_datasets('mnist.pkl.gz')\n",
    "train_y,val_y,test_y=to_categorical(train_set[1]),to_categorical(val_set[1]),to_categorical(test_set[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x.shape:(16, 784),y.shape:(16, 10)\n"
     ]
    }
   ],
   "source": [
    "# 随机选择训练样本\n",
    "train_num = train_set[0].shape[0]\n",
    "def next_batch(batch_size):\n",
    "    idx=np.random.choice(train_num,batch_size)\n",
    "    return train_set[0][idx],train_y[idx]\n",
    "\n",
    "x,y= next_batch(16)\n",
    "print(\"x.shape:{},y.shape:{}\".format(x.shape,y.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADN1JREFUeJzt3X+s3XV9x/Hn23pbtLiF4qwN1OEIaBjJirtDnQx1iEHC\nLPyxSs1MtxCrmWxjcckI+0P+cFmjE0fioimjUjZFFwHhD5xiM0cMjHFhHVC6yY8VaVMoBDbBhXKh\n7/1xv5AL3PM9l/P79v18JDfnnO/7+z3fd77pq9/vOZ9zzicyE0n1vG7cDUgaD8MvFWX4paIMv1SU\n4ZeKMvxSUYZfKsrwS0UZfqmo149yZ8tjRR7BylHuUirlWX7Oc3kwFrNuX+GPiLOAy4FlwN9l5pa2\n9Y9gJe+OM/rZpaQWt+eORa/b82V/RCwD/hb4CHASsDEiTur1+SSNVj+v+U8FHsjMhzLzOeBbwPrB\ntCVp2PoJ/zHAI/Me722WvUxEbI6ImYiYmeVgH7uTNEhDf7c/M7dm5nRmTk+xYti7k7RI/YR/H7B2\n3uNjm2WSloB+wn8HcEJEvD0ilgPnAzcOpi1Jw9bzUF9mPh8RFwLfZ26ob1tm7hpYZ5KGqq9x/sy8\nCbhpQL1IGiE/3isVZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGG\nXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VJThl4oy/FJR\nfc3SGxF7gKeBF4DnM3N6EE1JAA9+8b2t9d0f/0prfSqWdayd/oebW7d9w3f/rbV+OOgr/I0PZuYT\nA3geSSPkZb9UVL/hT+AHEXFnRLRfR0maKP1e9p+Wmfsi4i3AzRHxn5l5y/wVmv8UNgMcwRv73J2k\nQenrzJ+Z+5rbA8D1wKkLrLM1M6czc3qKFf3sTtIA9Rz+iFgZEW968T7wYeDeQTUmabj6uexfDVwf\nES8+zzcz858G0pWkoes5/Jn5EPBrA+xFxTz6p7/ZWv/Rx77QWp/N5b3vPHvf9HDhUJ9UlOGXijL8\nUlGGXyrK8EtFGX6pqEF8q0/qyTNrD7XWV72uj6E8deWZXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeK\ncpxfQ/XM7767Y+3a8y7vsnW0Vr/2P+9srf9wQ+dfkl/58K7Wbds/gXB48MwvFWX4paIMv1SU4ZeK\nMvxSUYZfKsrwS0U5zq++PHvOqyZpepnP/dW2jrUTp9rH8bvZfsVZrfW33ndrX89/uPPMLxVl+KWi\nDL9UlOGXijL8UlGGXyrK8EtFdR3nj4htwDnAgcw8uVm2Cvg2cBywB9iQmU8Nr01Nqv2/92xr/YNv\naKsva912054Ptdbfernj+P1YzJn/KuCVn6a4GNiRmScAO5rHkpaQruHPzFuAJ1+xeD2wvbm/HTh3\nwH1JGrJeX/Ovzsz9zf1HgdUD6kfSiPT9hl9mJpCd6hGxOSJmImJmloP97k7SgPQa/sciYg1Ac3ug\n04qZuTUzpzNzeooVPe5O0qD1Gv4bgU3N/U3ADYNpR9KodA1/RFwD3Aa8IyL2RsQFwBbgzIi4H/hQ\n81jSEtJ1nD8zN3YonTHgXjSBXn/sMa31Xb/19db6bL7QsbZ7tn3fP73sxNb6Sm5vfwK18hN+UlGG\nXyrK8EtFGX6pKMMvFWX4paL86e7ilv3qO1rr09+8d2j7/th1f9xaP/7afx3avuWZXyrL8EtFGX6p\nKMMvFWX4paIMv1SU4ZeKcpy/uIc/enRr/TtH/3uXZ2j/+e2PP/g7HWsnbnmwddvOXwbWIHjml4oy\n/FJRhl8qyvBLRRl+qSjDLxVl+KWiHOc/zD35B+9trV//6S92eYap1uqnH3l/a312U+dZml54/Kdd\n9q1h8swvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0V1HeePiG3AOcCBzDy5WXYp8Eng8Wa1SzLzpmE1\nqXZtv71/6+e/0mXrI/ra9217j2utr90zvN/9V38Wc+a/CjhrgeVfzsx1zZ/Bl5aYruHPzFuAJ0fQ\ni6QR6uc1/4URcXdEbIuIowbWkaSR6DX8XwWOB9YB+4EvdVoxIjZHxExEzMxysMfdSRq0nsKfmY9l\n5guZeQi4Aji1Zd2tmTmdmdNTdP6Sh6TR6in8EbFm3sPzAN/SlZaYxQz1XQN8AHhzROwFPgd8ICLW\nAQnsAT41xB4lDUHX8GfmxgUWXzmEXtSjn1zyxo612Rzur9+/bUt7PYe6d/XDT/hJRRl+qSjDLxVl\n+KWiDL9UlOGXivKnu5eAQ+8/pbX++envDm3fZ957fmv9yBk/37VUeeaXijL8UlGGXyrK8EtFGX6p\nKMMvFWX4paIc518C/vKqra31k6d6/+Lsn+0/vbX+ixufaq0P9wvDGibP/FJRhl8qyvBLRRl+qSjD\nLxVl+KWiDL9UlOP8S8Apy9v/j+7n57lv+/q7WutveerWnp9bk80zv1SU4ZeKMvxSUYZfKsrwS0UZ\nfqkowy8V1XWcPyLWAlcDq5mbcXlrZl4eEauAbwPHAXuADZnZ/uVvLeiR75zcWp+KnUPb95ofPdFa\n9/v6h6/FnPmfBz6bmScB7wE+ExEnARcDOzLzBGBH81jSEtE1/Jm5PzPvau4/DewGjgHWA9ub1bYD\n5w6rSUmD95pe80fEccApwO3A6szc35QeZe5lgaQlYtHhj4gjgWuBizLzZ/NrmZnMvR+w0HabI2Im\nImZmOdhXs5IGZ1Hhj4gp5oL/jcy8rln8WESsaeprgAMLbZuZWzNzOjOnp1gxiJ4lDUDX8EdEAFcC\nuzPzsnmlG4FNzf1NwA2Db0/SsCzmK73vAz4B3BPx0pjTJcAW4B8j4gLgYWDDcFpc+rpNsf036/6h\ntd7tK7v/e+jZjrXf+N5Frdu+8+H7Wus6fHUNf2b+GIgO5TMG246kUfETflJRhl8qyvBLRRl+qSjD\nLxVl+KWi/OnuEXh21fLW+mlH/LzLMyxrrX7//97WsXbi5jtatz3UZc86fHnml4oy/FJRhl8qyvBL\nRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paL8Pv8I/MLOR1vrf7T3t1vr\nX1v7L4NsRwI880tlGX6pKMMvFWX4paIMv1SU4ZeKMvxSUV3H+SNiLXA1sBpIYGtmXh4RlwKfBB5v\nVr0kM28aVqNL2fP//XBrfe972rc/h18fYDfSnMV8yOd54LOZeVdEvAm4MyJubmpfzsy/Hl57koal\na/gzcz+wv7n/dETsBo4ZdmOShus1veaPiOOAU4Dbm0UXRsTdEbEtIo7qsM3miJiJiJlZDvbVrKTB\nWXT4I+JI4Frgosz8GfBV4HhgHXNXBl9aaLvM3JqZ05k5PcWKAbQsaRAWFf6ImGIu+N/IzOsAMvOx\nzHwhMw8BVwCnDq9NSYPWNfwREcCVwO7MvGze8jXzVjsPuHfw7UkalsW82/8+4BPAPRGxs1l2CbAx\nItYxN/y3B/jUUDqUNBSLebf/x0AsUHJMX1rC/ISfVJThl4oy/FJRhl8qyvBLRRl+qSjDLxVl+KWi\nDL9UlOGXijL8UlGGXyrK8EtFGX6pqMjM0e0s4nFg/u9Yvxl4YmQNvDaT2tuk9gX21qtB9vbLmflL\ni1lxpOF/1c4jZjJzemwNtJjU3ia1L7C3Xo2rNy/7paIMv1TUuMO/dcz7bzOpvU1qX2BvvRpLb2N9\nzS9pfMZ95pc0JmMJf0ScFRH/FREPRMTF4+ihk4jYExH3RMTOiJgZcy/bIuJARNw7b9mqiLg5Iu5v\nbhecJm1MvV0aEfuaY7czIs4eU29rI+KfI+K+iNgVEX/SLB/rsWvpayzHbeSX/RGxDPgJcCawF7gD\n2JiZ9420kQ4iYg8wnZljHxOOiNOBZ4CrM/PkZtkXgCczc0vzH+dRmfnnE9LbpcAz4565uZlQZs38\nmaWBc4HfZ4zHrqWvDYzhuI3jzH8q8EBmPpSZzwHfAtaPoY+Jl5m3AE++YvF6YHtzfztz/3hGrkNv\nEyEz92fmXc39p4EXZ5Ye67Fr6WssxhH+Y4BH5j3ey2RN+Z3ADyLizojYPO5mFrC6mTYd4FFg9Tib\nWUDXmZtH6RUzS0/MsetlxutB8w2/VzstM98FfAT4THN5O5Fy7jXbJA3XLGrm5lFZYGbpl4zz2PU6\n4/WgjSP8+4C18x4f2yybCJm5r7k9AFzP5M0+/NiLk6Q2twfG3M9LJmnm5oVmlmYCjt0kzXg9jvDf\nAZwQEW+PiOXA+cCNY+jjVSJiZfNGDBGxEvgwkzf78I3Apub+JuCGMfbyMpMyc3OnmaUZ87GbuBmv\nM3Pkf8DZzL3j/yDwF+PooUNfvwL8R/O3a9y9Adcwdxk4y9x7IxcARwM7gPuBHwKrJqi3vwfuAe5m\nLmhrxtTbacxd0t8N7Gz+zh73sWvpayzHzU/4SUX5hp9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4\npaL+H5OL6YVERhITAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x1a62479f588>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 可视化\n",
    "import matplotlib.pyplot as plt\n",
    "digit=train_set[0][3]\n",
    "plt.imshow(np.reshape(digit,(28,28)))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 三、训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " epoch:0 step:0 ; loss:2.302584820875885\n",
      " train_acc:0.1875;  val_acc:0.103\n",
      "\n",
      " epoch:0 step:200 ; loss:2.3089974735813046\n",
      " train_acc:0.0625;  val_acc:0.1064\n",
      "\n",
      " epoch:0 step:400 ; loss:2.3190137162037106\n",
      " train_acc:0.0625;  val_acc:0.1064\n",
      "\n",
      " epoch:0 step:600 ; loss:2.29290016314387\n",
      " train_acc:0.1875;  val_acc:0.1064\n",
      "\n",
      " epoch:0 step:800 ; loss:2.2990879829286004\n",
      " train_acc:0.125;  val_acc:0.1064\n",
      "\n",
      " epoch:0 step:1000 ; loss:2.2969247354797817\n",
      " train_acc:0.125;  val_acc:0.1064\n",
      "\n",
      " epoch:0 step:1200 ; loss:2.307249383676819\n",
      " train_acc:0.09375;  val_acc:0.1064\n",
      "\n",
      " epoch:0 step:1400 ; loss:2.3215380862102757\n",
      " train_acc:0.03125;  val_acc:0.1064\n",
      "\n",
      " epoch:1 step:0 ; loss:2.2884130059797547\n",
      " train_acc:0.25;  val_acc:0.1064\n",
      "\n",
      " epoch:1 step:200 ; loss:1.76023258152068\n",
      " train_acc:0.34375;  val_acc:0.2517\n",
      "\n",
      " epoch:1 step:400 ; loss:1.4113708080481038\n",
      " train_acc:0.40625;  val_acc:0.3138\n",
      "\n",
      " epoch:1 step:600 ; loss:1.4484238805860425\n",
      " train_acc:0.53125;  val_acc:0.5509\n",
      "\n",
      " epoch:1 step:800 ; loss:0.4831932927037818\n",
      " train_acc:0.9375;  val_acc:0.7444\n",
      "\n",
      " epoch:1 step:1000 ; loss:0.521746944367524\n",
      " train_acc:0.84375;  val_acc:0.8234\n",
      "\n",
      " epoch:1 step:1200 ; loss:0.5975823718636631\n",
      " train_acc:0.875;  val_acc:0.8751\n",
      "\n",
      " epoch:1 step:1400 ; loss:0.39426304417143254\n",
      " train_acc:0.9375;  val_acc:0.8939\n",
      "\n",
      " epoch:2 step:0 ; loss:0.3392397455325375\n",
      " train_acc:0.9375;  val_acc:0.8874\n",
      "\n",
      " epoch:2 step:200 ; loss:0.2349061434167009\n",
      " train_acc:0.96875;  val_acc:0.9244\n",
      "\n",
      " epoch:2 step:400 ; loss:0.1642980488678663\n",
      " train_acc:0.96875;  val_acc:0.9223\n",
      "\n",
      " epoch:2 step:600 ; loss:0.18962678031295344\n",
      " train_acc:1.0;  val_acc:0.9349\n",
      "\n",
      " epoch:2 step:800 ; loss:0.1374088809322303\n",
      " train_acc:1.0;  val_acc:0.9365\n",
      "\n",
      " epoch:2 step:1000 ; loss:0.45885105735878895\n",
      " train_acc:0.96875;  val_acc:0.939\n",
      "\n",
      " epoch:2 step:1200 ; loss:0.049076886226820146\n",
      " train_acc:1.0;  val_acc:0.9471\n",
      "\n",
      " epoch:2 step:1400 ; loss:0.3464252344080918\n",
      " train_acc:0.9375;  val_acc:0.9413\n",
      "\n",
      " epoch:3 step:0 ; loss:0.2719433362166901\n",
      " train_acc:0.96875;  val_acc:0.9517\n",
      "\n",
      " epoch:3 step:200 ; loss:0.06844332074679768\n",
      " train_acc:1.0;  val_acc:0.9586\n",
      "\n",
      " epoch:3 step:400 ; loss:0.16346902137921188\n",
      " train_acc:1.0;  val_acc:0.9529\n",
      "\n",
      " epoch:3 step:600 ; loss:0.15661875582989374\n",
      " train_acc:1.0;  val_acc:0.9555\n",
      "\n",
      " epoch:3 step:800 ; loss:0.10004190054365474\n",
      " train_acc:1.0;  val_acc:0.9579\n",
      "\n",
      " epoch:3 step:1000 ; loss:0.20624793312023684\n",
      " train_acc:0.96875;  val_acc:0.9581\n",
      "\n",
      " epoch:3 step:1200 ; loss:0.016292493383161803\n",
      " train_acc:1.0;  val_acc:0.9602\n",
      "\n",
      " epoch:3 step:1400 ; loss:0.08761421046492293\n",
      " train_acc:1.0;  val_acc:0.9602\n",
      "\n",
      " epoch:4 step:0 ; loss:0.23058956036352923\n",
      " train_acc:0.9375;  val_acc:0.9547\n",
      "\n",
      " epoch:4 step:200 ; loss:0.14973880899309255\n",
      " train_acc:0.96875;  val_acc:0.9674\n",
      "\n",
      " epoch:4 step:400 ; loss:0.4563995699690676\n",
      " train_acc:0.9375;  val_acc:0.9667\n",
      "\n",
      " epoch:4 step:600 ; loss:0.03818259411193518\n",
      " train_acc:1.0;  val_acc:0.9641\n",
      "\n",
      " epoch:4 step:800 ; loss:0.18057951765239755\n",
      " train_acc:1.0;  val_acc:0.968\n",
      "\n",
      " epoch:4 step:1000 ; loss:0.05313018618481231\n",
      " train_acc:1.0;  val_acc:0.9656\n",
      "\n",
      " epoch:4 step:1200 ; loss:0.07373341371929959\n",
      " train_acc:1.0;  val_acc:0.9692\n",
      "\n",
      " epoch:4 step:1400 ; loss:0.0499225679993673\n",
      " train_acc:1.0;  val_acc:0.9696\n",
      "\n",
      " final result test_acc:0.9674;  val_acc:0.9676\n"
     ]
    }
   ],
   "source": [
    "# 初始化变量\n",
    "batch_size=32\n",
    "epoch = 3\n",
    "steps = train_num // batch_size\n",
    "lr = 0.1\n",
    "\n",
    "for e in range(epoch):\n",
    "    for s in range(steps):\n",
    "        X,y=next_batch(batch_size)\n",
    "        \n",
    "        # 前向过程\n",
    "        forward(X)\n",
    "        loss=backward(X,y)\n",
    "        \n",
    "        # 更新梯度\n",
    "        for k in [\"W1\",\"b1\",\"W2\",\"b2\",\"W3\",\"b3\"]:\n",
    "            weights[k]-=lr*gradients[k]\n",
    "        \n",
    "        if s % 500 ==0:\n",
    "            print(\"\\n epoch:{} step:{} ; loss:{}\".format(e,s,loss))\n",
    "            print(\" train_acc:{};  val_acc:{}\".format(get_accuracy(X,y),get_accuracy(val_set[0],val_y)))\n",
    "\n",
    "            \n",
    "print(\"\\n final result test_acc:{};  val_acc:{}\".\n",
    "      format(get_accuracy(test_set[0],test_y),get_accuracy(val_set[0],val_y)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADLxJREFUeJzt3X/MnXV5x/H3RS0l/DChQ5paUBjrHIQ/wDzgpmzRMRAY\nS9FsTP5gXUKsySSTxGQj7I/xx5LhMjUsGpIijWVTcYkSmoVtYuMkJFvHU1Z+2QmoNbQrLQwdRaW/\nuPbHc2Me4Tn3eTi/7vP0er+SJ88593X/uHKnn+c+5/6enm9kJpLqOa7rBiR1w/BLRRl+qSjDLxVl\n+KWiDL9UlOGXijL8UlGGXyrqLZM82PGxIk/gpEkeUirlFX7CoTwYi1l3qPBHxBXA7cAy4AuZeVvb\n+idwEu+JS4c5pKQW23Lrotcd+GV/RCwDPg9cCZwHXBcR5w26P0mTNcx7/ouBZzLz+5l5CLgHWDea\ntiSN2zDhXwM8O+/57mbZL4iIDRExGxGzhzk4xOEkjdLY7/Zn5sbMnMnMmeWsGPfhJC3SMOHfA5w5\n7/kZzTJJS8Aw4X8YWBsRZ0fE8cBHgC2jaUvSuA081JeZRyLiRuBfmRvq25SZT46sM0ljNdQ4f2be\nD9w/ol4kTZAf75WKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qSjD\nLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqmo\noWbpjYhdwAHgKHAkM2dG0ZSOHT/+o9/oWdt22x2t2573+T9prb/jU//ZWs8jR1rr1Q0V/sYHMvOF\nEexH0gT5sl8qatjwJ/CNiNgeERtG0ZCkyRj2Zf8lmbknIk4HHoiI/87MB+ev0PxR2ABwAicOeThJ\nozLUlT8z9zS/9wP3AhcvsM7GzJzJzJnlrBjmcJJGaODwR8RJEXHKa4+By4EnRtWYpPEa5mX/KuDe\niHhtP1/OzH8ZSVeSxi4yc2IHe2uszPfEpRM7nsbvLWve3lr/xLcf6Fm7/MTDQx37ynf9Zmv91QMH\nhtr/UrQtt/JSvhiLWdehPqkowy8VZfilogy/VJThl4oy/FJRo/hffSps/wff2VofZjjv3bN/2Fp/\n28tPDbxveeWXyjL8UlGGXyrK8EtFGX6pKMMvFWX4paIc51er405s/+q1D/7pQ2M79op7Tm1fYYL/\nHf1Y5JVfKsrwS0UZfqkowy8VZfilogy/VJThl4pynF+tDr733Nb6X51+18D7/umrh1rrb/3yfwy8\nb/XnlV8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXiuo7zh8Rm4Crgf2ZeX6zbCXwVeAsYBdwbWb+aHxt\nqis/+PCyse3795++ps8a/zO2Y2txV/4vAle8btnNwNbMXAtsbZ5LWkL6hj8zHwRefN3idcDm5vFm\noN+fcElTZtD3/Ksyc2/z+Dlg1Yj6kTQhQ9/wy8wEen6ZWkRsiIjZiJg9zMFhDydpRAYN/76IWA3Q\n/N7fa8XM3JiZM5k5s5wVAx5O0qgNGv4twPrm8XrgvtG0I2lS+oY/Ir4C/DvwrojYHRE3ALcBl0XE\n08DvNM8lLSF9x/kz87oepUtH3Ium0O9e9OhQ2//fqz/rWTt8a/t94uMc5x8rP+EnFWX4paIMv1SU\n4ZeKMvxSUYZfKsqv7i7u4FUXtdY/t+bOofa/+0jv2nHf/q+h9q3heOWXijL8UlGGXyrK8EtFGX6p\nKMMvFWX4paIc5y9u30XLx7r/3/unm3rW1rJtrMdWO6/8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU\n4/zFHX/hcDOr7zz009b6r/3dCz1rR4c6sobllV8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXiuo7zh8R\nm4Crgf2ZeX6z7Fbgo8DzzWq3ZOb942pSg3vl6otb67MX3dFnD8taq989fHpr/ehT3+uzf3VlMVf+\nLwJXLLD8s5l5QfNj8KUlpm/4M/NB4MUJ9CJpgoZ5z39jRDwWEZsi4tSRdSRpIgYN/x3AOcAFwF7g\n071WjIgNETEbEbOHOTjg4SSN2kDhz8x9mXk0M18F7gR63lXKzI2ZOZOZM8tZMWifkkZsoPBHxOp5\nTz8EPDGadiRNymKG+r4CvB84LSJ2A38JvD8iLgAS2AV8bIw9ShqDvuHPzOsWWHzXGHrRGPzstPZx\n+uXRXu/nz7Z/uLV+No8NtX+Nj5/wk4oy/FJRhl8qyvBLRRl+qSjDLxXlV3cf4w5e8+Ohtu/31dxn\nfGG8U3xrfLzyS0UZfqkowy8VZfilogy/VJThl4oy/FJRjvMfA5b96jk9a7MX/UO/rVur//zy+a31\n5d/c3mf/mlZe+aWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMf5jwH7PtB7muxhv5r7c9+6rLW+lm1D\n7V/d8covFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0X1HeePiDOBu4FVQAIbM/P2iFgJfBU4C9gFXJuZ\nPxpfq+rllZUx8LbbDx5qrZ/7qd2t9SMDH1ldW8yV/wjwycw8D/h14OMRcR5wM7A1M9cCW5vnkpaI\nvuHPzL2Z+Ujz+ACwE1gDrAM2N6ttBq4ZV5OSRu9NveePiLOAC4FtwKrM3NuUnmPubYGkJWLR4Y+I\nk4GvATdl5kvza5mZzN0PWGi7DRExGxGzhzk4VLOSRmdR4Y+I5cwF/0uZ+fVm8b6IWN3UVwP7F9o2\nMzdm5kxmzixnxSh6ljQCfcMfEQHcBezMzM/MK20B1jeP1wP3jb49SeOymP/S+z7geuDxiNjRLLsF\nuA34x4i4AfghcO14WlQ/p//2noG33fLSha31o8+/MPC+Nd36hj8zHwJ6DSRfOtp2JE2Kn/CTijL8\nUlGGXyrK8EtFGX6pKMMvFeVXdy8BsaL9k5Hr3v7owPv+30Mnt9bzoB/JPlZ55ZeKMvxSUYZfKsrw\nS0UZfqkowy8VZfilohznXwqOHm0tb9x5Sc/aTe/d1brtvz37K631NTzZWtfS5ZVfKsrwS0UZfqko\nwy8VZfilogy/VJThl4pynH8JyCPtE2GfdfNPetbO/evrW7eNHacM1JOWPq/8UlGGXyrK8EtFGX6p\nKMMvFWX4paIMv1RU33H+iDgTuBtYBSSwMTNvj4hbgY8Czzer3pKZ94+rUfV29Jkf9Ky94w8m2IiW\nlMV8yOcI8MnMfCQiTgG2R8QDTe2zmfm342tP0rj0DX9m7gX2No8PRMROYM24G5M0Xm/qPX9EnAVc\nCGxrFt0YEY9FxKaIOLXHNhsiYjYiZg/j1E/StFh0+CPiZOBrwE2Z+RJwB3AOcAFzrww+vdB2mbkx\nM2cyc2Y57XPOSZqcRYU/IpYzF/wvZebXATJzX2YezcxXgTuBi8fXpqRR6xv+iAjgLmBnZn5m3vLV\n81b7EPDE6NuTNC6Ludv/PuB64PGI2NEsuwW4LiIuYG74bxfwsbF0KGksFnO3/yEgFig5pi8tYX7C\nTyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VFRk5uQOFvE8\n8MN5i04DXphYA2/OtPY2rX2BvQ1qlL29MzPftpgVJxr+Nxw8YjYzZzproMW09jatfYG9Daqr3nzZ\nLxVl+KWiug7/xo6P32Zae5vWvsDeBtVJb52+55fUna6v/JI60kn4I+KKiPhuRDwTETd30UMvEbEr\nIh6PiB0RMdtxL5siYn9EPDFv2cqIeCAinm5+LzhNWke93RoRe5pztyMiruqotzMj4lsR8Z2IeDIi\nPtEs7/TctfTVyXmb+Mv+iFgGPAVcBuwGHgauy8zvTLSRHiJiFzCTmZ2PCUfEbwEvA3dn5vnNsr8B\nXszM25o/nKdm5p9PSW+3Ai93PXNzM6HM6vkzSwPXAH9Mh+eupa9r6eC8dXHlvxh4JjO/n5mHgHuA\ndR30MfUy80HgxdctXgdsbh5vZu4fz8T16G0qZObezHykeXwAeG1m6U7PXUtfnegi/GuAZ+c93810\nTfmdwDciYntEbOi6mQWsaqZNB3gOWNVlMwvoO3PzJL1uZumpOXeDzHg9at7we6NLMvPdwJXAx5uX\nt1Mp596zTdNwzaJmbp6UBWaW/rkuz92gM16PWhfh3wOcOe/5Gc2yqZCZe5rf+4F7mb7Zh/e9Nklq\n83t/x/383DTN3LzQzNJMwbmbphmvuwj/w8DaiDg7Io4HPgJs6aCPN4iIk5obMUTEScDlTN/sw1uA\n9c3j9cB9HfbyC6Zl5uZeM0vT8bmbuhmvM3PiP8BVzN3x/x7wF1300KOvXwYebX6e7Lo34CvMvQw8\nzNy9kRuAXwK2Ak8D3wRWTlFvfw88DjzGXNBWd9TbJcy9pH8M2NH8XNX1uWvpq5Pz5if8pKK84ScV\nZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qaj/B9sc50PLIFA7AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x1a616bde438>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "y_true:1,y_predict:1\n"
     ]
    }
   ],
   "source": [
    "# 查看预测结果\n",
    "x,y=test_set[0][5],test_y[5]\n",
    "plt.imshow(np.reshape(x,(28,28)))\n",
    "plt.show()\n",
    "\n",
    "y_predict = np.argmax(forward([x])[0])\n",
    "\n",
    "print(\"y_true:{},y_predict:{}\".format(np.argmax(y),y_predict))\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
